# @File    : wraper.py
import enum
import os
import random
import sys
import threading
import time
from typing import Union

import numpy
import pandas
import scipy.interpolate

import ezmath

# Project root: the parent of the directory containing this file.
MAGLAB_DIR = os.path.join(os.path.split(__file__)[0], "..")
# Register the native runtime DLL directory so libhallmachine's dependencies can be
# resolved when it is imported (os.add_dll_directory is Windows-only, Python 3.8+).
os.add_dll_directory(os.path.join(MAGLAB_DIR, r"core\import\runtime_x86"))
# Default directory for the calibration CSV files, next to this module.
DEFAULT_CFG_DIR = os.path.join(os.path.split(__file__)[0], "config")
# libhallmachine is imported from core\Debug, which is appended to sys.path here.
sys.path.append(os.path.join(MAGLAB_DIR, r"core\Debug"))
import libhallmachine

# Instance guard for HallMachineWraper: starts at 1 and is incremented by
# HallMachineWraper.__init__; construction is refused once it exceeds 1,
# so only one HallMachine object may be running.
hm_cnt = 1


class HeaderNotMatchException(RuntimeError):
    """
    Raised when the header row of a CSV file does not match the expected column names.
    """

    def __init__(self, columns):
        """
        :param columns: the expected column names; included in the error message.
        """
        # Wrap in a 1-tuple so that a tuple of column names is rendered as a whole
        # instead of being consumed as several %-format arguments (TypeError).
        super(HeaderNotMatchException, self).__init__("列名不匹配，预期的标题行：\n%s" % (columns,))


class Vec3:
    """
    Minimal mutable container with ``x``, ``y`` and ``z`` attributes.

    Components default to ``None`` and may hold any object.
    """

    def __init__(self, x=None, y=None, z=None):
        self.x, self.y, self.z = x, y, z

    def __repr__(self):
        return "%s(x=%r, y=%r, z=%r)" % (type(self).__name__, self.x, self.y, self.z)

    def from_libhallmachine_Vec3(self, vec3: "libhallmachine.Vec3d"):
        """
        Copy the components of ``vec3`` into this object in place.

        :param vec3: any object exposing ``x``, ``y`` and ``z`` attributes,
            e.g. ``libhallmachine.Vec3d``.
        :return: ``self``, to allow chaining.
        """
        # Assign directly rather than re-running __init__ from a regular method.
        self.x, self.y, self.z = vec3.x, vec3.y, vec3.z
        return self


def vec3d_to_numpy(v: Vec3) -> numpy.ndarray:
    """
    Pack ``v.x``, ``v.y`` and ``v.z`` into a 1-D numpy array.

    Works with any object exposing those attributes (``Vec3``, ``libhallmachine.Vec3d``,
    ``UMAC_Pos``, ...).
    """
    return numpy.array([v.x, v.y, v.z])


def numpy_to_vec3(arr: numpy.ndarray) -> Vec3:
    """
    Build a ``Vec3`` from the elements of ``arr`` (x, y, z in order).

    Fewer than three elements leave the trailing components ``None``; more than three raise
    ``TypeError``.
    """
    return Vec3(*arr)


class Axis(enum.Enum):
    """
    Identifiers of the machine axes / position readouts.

    NOTE(review): the integer values are not interpreted by the code visible in this module;
    confirm whether they must match an index used by the controller before changing them.
    """
    Z = 0
    X = 1
    Y = 2
    A = 3
    C = 4
    Z2 = 5  # laser-scale reading of the Z axis


class UMAC_Pos(libhallmachine.Vec3d):
    """
    Position record of the UMAC motion controller.

    Inherits ``x``, ``y``, ``z`` from ``libhallmachine.Vec3d`` and adds the ``a`` and ``c``
    angles plus ``z2`` (the laser-scale reading of the Z axis). When built from
    ``libhallmachine.UMACData``, x/y/z/z2 come from the ``*_pos_mm`` fields and a/c from the
    ``*_ang_deg`` fields.
    """

    def __init__(self, x=0, y=0, z=0, a=0, c=0, z2=None):
        """
        :param z2: laser-scale Z reading; falls back to ``z`` when omitted.
        """
        super(UMAC_Pos, self).__init__(x, y, z)
        self.a = a
        self.c = c
        self.z2 = z2 if z2 is not None else z
        # Per-instance accessor table; the lambdas read the attributes lazily so they
        # always reflect the current values.
        self.d_get_pos_by_axis = {
            Axis.X: lambda: self.x,
            Axis.Y: lambda: self.y,
            Axis.Z: lambda: self.z,
            Axis.A: lambda: self.a,
            Axis.C: lambda: self.c,
            Axis.Z2: lambda: self.z2,
        }

    def to_Vec3d(self) -> libhallmachine.Vec3d:
        """
        Return a new ``libhallmachine.Vec3d`` holding only x, y and z.

        :return: an independent ``Vec3d`` copy.
        """
        # Previously this returned ``super(UMAC_Pos, self)``, which is a ``super`` proxy,
        # not a Vec3d instance, and cannot be passed where a real Vec3d is required.
        return libhallmachine.Vec3d(self.x, self.y, self.z)

    def __str__(self):
        return "{x: %f, y: %f, z: %f, z2: %f, a: %f, c: %f}" % (self.x, self.y, self.z, self.z2, self.a, self.c)

    def get_pos_by_axis(self, axis: Axis):
        """
        Return the coordinate of the given axis.

        :param axis: one of the ``Axis`` members.
        :return: the stored value for that axis.
        :raises KeyError: if ``axis`` is not an ``Axis`` member.
        """
        return self.d_get_pos_by_axis[axis]()

    @staticmethod
    def from_UMACData(umac_data: libhallmachine.UMACData):
        """
        Build a new ``UMAC_Pos`` from a ``libhallmachine.UMACData`` snapshot.

        :param umac_data: controller data to copy from.
        :return: a new ``UMAC_Pos``.
        """
        return UMAC_Pos(umac_data.x_pos_mm, umac_data.y_pos_mm, umac_data.z_pos_mm,
                        umac_data.a_ang_deg, umac_data.c_ang_deg, umac_data.z2_pos_mm)

    def from_UMACData_inplace(self, umac_data: libhallmachine.UMACData):
        """
        Re-initialise this object in place from a ``libhallmachine.UMACData`` snapshot.
        """
        self.__init__(umac_data.x_pos_mm, umac_data.y_pos_mm, umac_data.z_pos_mm,
                      umac_data.a_ang_deg, umac_data.c_ang_deg, umac_data.z2_pos_mm)


def add_prefix(origin_path: str, prefix: str) -> str:
    """
    Prepend ``prefix`` to the file-name component of ``origin_path``.

    :param origin_path: path whose last component gets the prefix.
    :param prefix: text inserted in front of the file name.
    :return: the new path; the directory part is unchanged.
    """
    # ``directory`` instead of ``dir`` to avoid shadowing the builtin.
    directory, filename = os.path.split(origin_path)
    return os.path.join(directory, prefix + filename)


def append_n_rows_with_nan(df: pandas.DataFrame, n) -> pandas.DataFrame:
    """
    Return a copy of ``df`` extended by ``n`` rows filled with NaN.

    The input frame is left untouched, and the result gets a fresh ``RangeIndex``.

    :param df: frame to extend.
    :param n: number of all-NaN rows to append.
    :return: the extended frame.
    """
    row_count, column_count = n, df.shape[1]
    padding = pandas.DataFrame(numpy.full((row_count, column_count), numpy.nan), columns=df.columns)
    return pandas.concat([df, padding], ignore_index=True)


def columns_is_match(df, columns: list):
    """
    Check that ``df`` has exactly the columns ``columns``, in the same order.

    :return: ``False`` if the column counts differ, otherwise the element-wise comparison
        reduced with ``numpy.all``.
    """
    if len(df.columns) != len(columns):
        return False
    return numpy.all(df.columns == columns)


class RecorderBase:
    """
    Tabular recorder of measurement samples, backed by ``self.df`` (a ``pandas.DataFrame``
    with the columns listed in ``columns``).

    Many of the columns are redundant; they are kept mainly for plotting convenience.
    """
    columns = "timestamp,x_mm,y_mm,z_mm,s_mm,A_deg,C_deg,Bx_Gauss,By_Gauss,Bz_Gauss,B_Gauss,I_A,Vx_V,Vy_V,Vz_V,V_V,remain_records,tag".split(
        ",")
    k_timestamp, k_x_mm, k_y_mm, k_z_mm, k_s_mm, k_A_deg, k_C_deg, \
    k_Bx_Gauss, k_By_Gauss, k_Bz_Gauss, k_B_Gauss, k_I_A, \
    k_Vx_V, k_Vy_V, k_Vz_V, k_V_V, \
    k_remain_records, k_tag = \
        columns
    columns_xyz = [k_x_mm, k_y_mm, k_z_mm]
    columns_not_support_mean = [k_tag]  # columns that cannot be averaged

    def __init__(self, autogen_path_method=None):
        """
        :param autogen_path_method: zero-argument callable returning the CSV path used by
            ``read_csv`` / ``to_csv``; may be None if those methods are never called.
        """
        print("Create a recorder")
        self.df = pandas.DataFrame(columns=self.columns)
        self.autogen_path_method = autogen_path_method

    def __len__(self):
        return self.df.shape[0]

    def _xyz_steps(self) -> numpy.ndarray:
        """Euclidean length of each step between consecutive rows (``len(self) - 1`` values)."""
        xyz = self.df[RecorderBase.columns_xyz].to_numpy(dtype=float)
        return numpy.linalg.norm(numpy.diff(xyz, axis=0), axis=1)

    def s(self) -> numpy.ndarray:
        """
        Cumulative path length from the first row, one value per row (``s[0] == 0``).

        Returns exactly ``len(self)`` values. The previous implementation returned
        ``len(self) + 1`` values with a duplicated leading 0.
        """
        return numpy.cumsum(self.ds())

    def ds(self) -> numpy.ndarray:
        """
        Distance moved since the previous row, one value per row; the first entry is 0.

        :return: empty array for an empty recorder.
        """
        if self.df.shape[0] == 0:
            return numpy.array([])
        # Vectorised: avoids re-slicing the whole frame once per row (was O(n^2)).
        return numpy.concatenate(([0.0], self._xyz_steps()))

    def distance(self, i, j):
        """
        Euclidean distance between the rows labelled ``i`` and ``j`` in the x/y/z columns.
        """
        pt1 = self.df.loc[i, RecorderBase.columns_xyz]
        pt2 = self.df.loc[j, RecorderBase.columns_xyz]
        return numpy.linalg.norm((pt2 - pt1).to_numpy(dtype=float))

    def read_csv(self):
        """
        Load ``self.df`` from the path returned by ``autogen_path_method``.

        :raises HeaderNotMatchException: if the file's header differs from ``columns``;
            ``self.df`` is left unchanged in that case.
        """
        df = pandas.read_csv(self.autogen_path_method())
        if not self.columns_is_match(df):
            raise HeaderNotMatchException(self.columns)
        self.df = df

    def to_csv(self):
        """Write ``self.df`` (without index) to the path returned by ``autogen_path_method``."""
        path = self.autogen_path_method()
        self.df.to_csv(path, index=False)
        print("数据已写入 %s" % path)

    @staticmethod
    def columns_is_match(df):
        """True if ``df`` has exactly ``RecorderBase.columns`` in order."""
        return columns_is_match(df, RecorderBase.columns)

    def mean_last_n(self, n):
        """
        Replace the last ``n`` rows by a single row holding their column-wise mean.

        Columns in ``columns_not_support_mean`` are not averaged and become NaN in the
        merged row. Values that cannot be parsed as numbers are treated as NaN. Does nothing
        for ``n <= 1``; if ``n`` exceeds the row count, all rows are merged.
        """
        if n <= 1:
            return
        # Positional split, clamped at 0 so n > len(self) merges everything.
        split = max(self.df.shape[0] - n, 0)
        head = self.df.iloc[:split]
        tail = self.df.iloc[split:]
        # Averaging the string ``tag`` column raises TypeError on pandas >= 2, so skip it.
        mean_columns = [c for c in self.df.columns if c not in self.columns_not_support_mean]
        mean_row = tail[mean_columns].apply(pandas.to_numeric, errors="coerce").mean().to_frame().T
        self.df = pandas.concat([head, mean_row], ignore_index=True)

    def append_n_rows_with_nan(self, n):
        """
        Append ``n`` rows filled with NaN to ``self.df``.

        :param n: number of rows to append.
        """
        self.df = append_n_rows_with_nan(self.df, n)

    def integrate_value_along(self, columns_to_be_integrated: Union[list, numpy.ndarray], X: str):
        """
        Integrate the given columns with respect to column ``X``.

        Uses a left Riemann sum: the value at row ``j - 1`` times ``X[j] - X[j - 1]``.
        Rows are taken in positional order.

        :param columns_to_be_integrated: names of the columns to integrate.
        :param X: integration variable column, e.g. x, y, z, s or t.
        :return: a new ``RecorderBase`` with a single row holding the integrals; other
            columns are NaN. Fewer than two rows give all-zero integrals.
        """
        columns = list(columns_to_be_integrated)
        x = self.df[X].to_numpy(dtype=float)
        values = self.df[columns].to_numpy(dtype=float)
        sum_ = (numpy.diff(x)[:, None] * values[:-1]).sum(axis=0)
        res = RecorderBase(None)
        res.append_n_rows_with_nan(1)
        res.df[columns] = sum_
        return res

    def format(self):
        """
        Cast the columns to their intended dtypes.

        ``remain_records`` is cast to int, the non-averageable columns (``tag``) to str, and
        everything else to float.
        """
        columns_to_int = {self.k_remain_records}
        columns_to_str = set(self.columns_not_support_mean)
        columns_to_float = set(self.columns) - columns_to_int - columns_to_str
        self.df[list(columns_to_int)] = self.df[list(columns_to_int)].astype(int)
        self.df[list(columns_to_float)] = self.df[list(columns_to_float)].astype(float)
        self.df[list(columns_to_str)] = self.df[list(columns_to_str)].astype(str)


class DefaultFilePath:
    """
    Default locations of the calibration CSV files inside ``DEFAULT_CFG_DIR``.

    The ``mean_*`` paths carry the ``prefix_mean`` file-name prefix.
    """
    prefix_mean = "mean."

    # Background-voltage table.
    __VBackground = "VBackground.csv"
    raw_VBackground = os.path.join(DEFAULT_CFG_DIR, __VBackground)
    mean_VBackground = add_prefix(raw_VBackground, prefix_mean)

    # Voltage-to-field table.
    __VtoB = "VtoB.csv"
    __raw_VtoB = os.path.join(DEFAULT_CFG_DIR, __VtoB)
    mean_VtoB = add_prefix(__raw_VtoB, prefix_mean)

    @staticmethod
    def raw_V_to_B_componet(axis: Axis) -> str:
        """
        Path of the raw per-component V-to-B file: the raw VtoB path with ``_<axis>``
        inserted before the extension.

        NOTE(review): ``axis`` is rendered with ``%s``, which for a plain ``Enum`` member
        gives e.g. ``Axis.X`` rather than ``X``. Confirm this naming is intended.
        """
        stem, suffix = os.path.splitext(DefaultFilePath.__raw_VtoB)
        return "".join([stem, "_%s" % axis, suffix])


class HallMachineDataProcessor:
    """
    Converts raw multimeter voltages into magnetic flux density (Gauss).

    Two calibration tables are loaded from ``cfg_dir`` at construction time:

    * background voltage vs. position (``mean.VBackground.csv``). It is interpolated over
      x/y/z with ``ezmath.Polator``, one interpolator per voltage component.
    * relative voltage vs. field (``mean.VtoB.csv``). It is fitted per component with
      ``scipy.interpolate.UnivariateSpline``.

    If a table is missing or its header does not match ``RecorderBase.columns``, a rough
    placeholder is generated in memory instead. The placeholder is not written to disk.
    """

    def __init__(self, cfg_dir=DEFAULT_CFG_DIR):
        """
        :param cfg_dir: directory containing the calibration CSV files; created if absent.
        """
        self.__cfg_dir = cfg_dir
        os.makedirs(self.__cfg_dir, exist_ok=True)
        self.__recorder_V_background = self._load_V_background()
        self.__recorder_V_to_B = self._load_V_to_B()

        df_background = self.__recorder_V_background.df
        xyz = df_background[RecorderBase.columns_xyz]
        self.__polator_V_background = Vec3(
            ezmath.Polator(xyz, df_background[RecorderBase.k_Vx_V]),
            ezmath.Polator(xyz, df_background[RecorderBase.k_Vy_V]),
            ezmath.Polator(xyz, df_background[RecorderBase.k_Vz_V]),
        )
        df_V_to_B = self.__recorder_V_to_B.df
        self.__polator_V_to_B = Vec3(
            self._fit_V_to_B(df_V_to_B, RecorderBase.k_Vx_V, RecorderBase.k_Bx_Gauss),
            self._fit_V_to_B(df_V_to_B, RecorderBase.k_Vy_V, RecorderBase.k_By_Gauss),
            self._fit_V_to_B(df_V_to_B, RecorderBase.k_Vz_V, RecorderBase.k_Bz_Gauss),
        )

    def _cfg_file(self, default_path: str) -> str:
        """Path of ``default_path``'s file name inside this processor's ``cfg_dir``."""
        # For the default cfg_dir this reproduces DefaultFilePath's paths exactly;
        # previously a custom cfg_dir was created but never read from.
        return os.path.join(self.__cfg_dir, os.path.basename(default_path))

    def _load_V_background(self) -> RecorderBase:
        """Load the background-voltage table, or build a one-point all-zero placeholder."""
        path = self._cfg_file(DefaultFilePath.mean_VBackground)
        recorder = RecorderBase(autogen_path_method=lambda: path)
        try:
            recorder.read_csv()
        except (HeaderNotMatchException, FileNotFoundError):
            # Rough but usable fallback data.
            print("正在生成备用数据")
            n = 1
            recorder.append_n_rows_with_nan(n)
            values = numpy.zeros((n, 6), dtype=int)
            values[:, :3] = numpy.arange(n)[:, None]  # row i sits at (i, i, i) with zero voltage
            recorder.df[
                [recorder.k_x_mm, recorder.k_y_mm, recorder.k_z_mm,
                 recorder.k_Vx_V, recorder.k_Vy_V, recorder.k_Vz_V]] = values
        return recorder

    def _load_V_to_B(self) -> RecorderBase:
        """Load the V-to-B table, or build a 4-point linear placeholder (B = 200 * V)."""
        path = self._cfg_file(DefaultFilePath.mean_VtoB)
        recorder = RecorderBase(autogen_path_method=lambda: path)
        try:
            recorder.read_csv()
            print("Use V to B data: %s\n%s" % (recorder.autogen_path_method(), recorder.df))
        except (HeaderNotMatchException, FileNotFoundError):
            # Rough fallback: 4 points, enough for UnivariateSpline's default degree (k=3 needs > 3).
            n = 4
            recorder.append_n_rows_with_nan(n)
            v = numpy.array(3 * [[10 * i for i in range(n)]]).T  # shape (n, 3): V = 0, 10, 20, 30
            recorder.df[[RecorderBase.k_Vx_V, RecorderBase.k_Vy_V, RecorderBase.k_Vz_V]] = v
            recorder.df[[RecorderBase.k_Bx_Gauss, RecorderBase.k_By_Gauss, RecorderBase.k_Bz_Gauss]] = 200 * v
        return recorder

    @staticmethod
    def _fit_V_to_B(df: pandas.DataFrame, column_V: str, column_B: str):
        """
        Fit B as a smooth function of V for one component.

        Rows where either value is NaN are dropped together so V and B stay paired.
        Points are sorted by V (stable sort) because UnivariateSpline requires increasing x.
        """
        pairs = df[[column_V, column_B]].dropna().astype(float).sort_values(column_V, kind="mergesort")
        return scipy.interpolate.UnivariateSpline(pairs[column_V], pairs[column_B])

    def __get_V_background_by_numpy(self, loc_: numpy.ndarray) -> numpy.ndarray:
        """
        Interpolated background voltage (x, y, z components) at position ``loc_``.

        :param loc_: position passed unchanged to each ``ezmath.Polator``.
        :return: flattened array of the three interpolated components.
        """
        return numpy.array([
            self.__polator_V_background.x(loc_),
            self.__polator_V_background.y(loc_),
            self.__polator_V_background.z(loc_)]
        ).ravel()

    def get_V_relative_by_numpy(self, V: numpy.ndarray, loc: numpy.ndarray) -> numpy.ndarray:
        """Raw voltage ``V`` minus the interpolated background voltage at ``loc``."""
        V_back = self.__get_V_background_by_numpy(loc)
        return V - V_back

    def __get_V_background(self, loc: libhallmachine.Vec3d) -> numpy.ndarray:
        loc_ = vec3d_to_numpy(loc)
        return self.__get_V_background_by_numpy(loc_)

    def __get_B_by_numpy(self, V_relative_V: numpy.ndarray) -> numpy.ndarray:
        """
        :param V_relative_V: relative voltage, i.e. reading minus background; indexed 0/1/2
            for the x/y/z components.
        :return: flattened array of Bx, By, Bz from the per-component splines.
        """
        return numpy.array([
            self.__polator_V_to_B.x(V_relative_V[0]),
            self.__polator_V_to_B.y(V_relative_V[1]),
            self.__polator_V_to_B.z(V_relative_V[2])]
        ).ravel()

    def calculate_B_by_numpy(self, V: numpy.ndarray, pos: numpy.ndarray) -> numpy.ndarray:
        """
        Compute the magnetic field.

        :param V: raw multimeter voltage at this point.
        :param pos: position of this point.
        :return: Bx, By, Bz.
        """
        V_relative = self.get_V_relative_by_numpy(V, pos)
        return self.__get_B_by_numpy(V_relative)


class HallMachineWraper(libhallmachine.HallMachine):
    """
    Wrapper around the C++ ``libhallmachine.HallMachine``.

    Adds data-processing logic, e.g. computing the magnetic flux density from the
    multimeter readings via ``HallMachineDataProcessor``. Only one instance may be created,
    enforced by the module-level ``hm_cnt`` counter.
    """

    def __init__(self, cfg_path: str = DEFAULT_CFG_DIR):
        """
        :param cfg_path: calibration directory handed to ``HallMachineDataProcessor``.
        :raises RuntimeError: if an instance has already been created.
        """
        global hm_cnt
        if hm_cnt > 1:
            raise RuntimeError("已有一个HallMachine对象在运行！")
        print("新建了一个HallMachine对象，内存中此类型对象的个数：%d" % hm_cnt)

        super(HallMachineWraper, self).__init__()
        # Magnetic flux density at the current position.
        self.B_ = Vec3(0, 0, 0)
        self._data_processor = HallMachineDataProcessor(cfg_path)
        # Count the instance only after construction succeeded, so a failed attempt
        # (e.g. bad calibration data) does not permanently block creating one later.
        hm_cnt = hm_cnt + 1

    def _current_pos_numpy(self) -> numpy.ndarray:
        """Current x/y/z position taken from ``umac_data_``, as a numpy array."""
        return vec3d_to_numpy(UMAC_Pos.from_UMACData(self.umac_data_))

    def set_to(self, x, y, z):
        """Jog to absolute (x, y, z); Z is commanded first, then X, then Y."""
        print("Jog to %.2f, %.2f, %.2f" % (x, y, z))
        self.z_jog_to(z)
        self.x_jog_to(x)
        self.y_jog_to(y)
        # TODO: current

    def set_delta(self, dx, dy, dz):
        """Jog by a relative (dx, dy, dz); Z is commanded first, then X, then Y."""
        self.z_jog(dz)
        self.x_jog(dx)
        self.y_jog(dy)
        # TODO: current

    def get_V_relative(self):
        """Current voltage ``V_`` minus the background voltage at the current position."""
        return self._data_processor.get_V_relative_by_numpy(vec3d_to_numpy(self.V_),
                                                            self._current_pos_numpy())

    def update_B(self):
        """
        Call ``update_GPIB`` (presumably refreshes ``V_`` from the multimeter — confirm in
        libhallmachine), then recompute ``B_`` from ``V_`` at the current position.
        """
        self.update_GPIB()
        self.B_.x, self.B_.y, self.B_.z = self._data_processor.calculate_B_by_numpy(
            vec3d_to_numpy(self.V_),
            self._current_pos_numpy())

    def update_all(self):
        """Refresh the UMAC data, then the voltages and ``B_``."""
        # The former ``if HallMachineAxisHelper(self).all_axis_stopped(): pass`` had no
        # effect and was removed.
        self.update_UMAC()
        self.update_B()

    def calculate_B_norm(self):
        """Euclidean norm of ``B_``."""
        return numpy.linalg.norm(vec3d_to_numpy(self.B_))


class FakeHallMachine(HallMachineWraper):
    """
    Test double with the same interface as HallMachine. It produces fake data so that
    features can be exercised without hardware.

    Jogs are simulated on background threads that move the axis by
    ``DEFAULT_SPEED * dt`` per step and sleep ``dt`` seconds between steps.
    """
    DEFAULT_SPEED = 10

    def __init__(self):
        super(FakeHallMachine, self).__init__()
        print("正在使用FakeHallMachine，它和HallMachine具有相同的接口，用于产生假数据以便验证一些功能")
        self.umac_data_.manual = False
        self.dt = 2e-3
        self.B_ = Vec3(0, 0, 0)

    def open(self):
        """No-op: there is no hardware connection to open."""
        pass

    def fake_latency(self):
        """
        Simulate network latency (10 ms).
        """
        time.sleep(1e-2)

    def fake_V_numpy(self, xyz) -> numpy.ndarray:
        """Fake voltages: a small z-proportional term plus Gaussian noise (sigma 1e-3)."""
        k = 1e-3
        uncertainty_xyz = numpy.random.randn(3) * k
        ks = numpy.array([0, 0, k / 1e3])
        return xyz * ks + uncertainty_xyz

    def _fake_jog_to(self, axis: str, target):
        """
        Start a thread that moves ``umac_data_.<axis>_pos_mm`` to ``target``.

        ``<axis>_stoped`` is False while moving and True once the target is reached. The
        final step is clamped so the axis stops exactly on ``target``; previously the fixed
        step size overshot it by up to one step.

        :param axis: attribute prefix, "x", "y" or "z".
        :param target: target position.
        """
        pos_attr = "%s_pos_mm" % axis
        stopped_attr = "%s_stoped" % axis
        data = self.umac_data_

        def move():
            setattr(data, stopped_attr, False)
            step = self.DEFAULT_SPEED * self.dt
            while True:
                pos = getattr(data, pos_attr)
                remaining = target - pos
                if abs(remaining) <= step:
                    break
                setattr(data, pos_attr, pos + (step if remaining > 0 else -step))
                time.sleep(self.dt)
            setattr(data, pos_attr, target)
            setattr(data, stopped_attr, True)

        threading.Thread(target=move).start()

    def x_jog_to(self, arg0):
        print("x_jog_to %s" % arg0)
        self._fake_jog_to("x", arg0)

    def z_jog_to(self, arg0):
        print("z_jog_to %s" % arg0)
        self._fake_jog_to("z", arg0)

    def y_jog_to(self, arg0):
        print("y_jog_to %s" % arg0)
        self._fake_jog_to("y", arg0)

    def update_UMAC(self):
        """Randomly set the z negative-limit flag (True with probability 3/4), then wait."""
        self.umac_data_.z_limit_pos_minus = bool(random.randint(0, 3))
        self.fake_latency()

    def update_GPIB(self):
        """Fill ``V_`` with fake voltages for the current position, then wait."""
        self.V_.x, self.V_.y, self.V_.z = self.fake_V_numpy(
            vec3d_to_numpy(UMAC_Pos.from_UMACData(self.umac_data_).to_Vec3d()))
        self.fake_latency()

    def update_B(self):
        """Refresh the fake voltages and recompute ``B_`` from them."""
        self.update_GPIB()
        self.B_.x, self.B_.y, self.B_.z = self._data_processor.calculate_B_by_numpy(
            vec3d_to_numpy(self.V_),
            vec3d_to_numpy(UMAC_Pos.from_UMACData(self.umac_data_)))

    def update_all(self):
        """Refresh the fake UMAC data, then the voltages and ``B_``."""
        self.update_UMAC()
        self.update_B()


# Public aliases used by the rest of the program.
# TODO: for debugging, HallMachine can be pointed at FakeHallMachine instead of HallMachineWraper.
HallMachine, UMACData = HallMachineWraper, libhallmachine.UMACData


class HallMachineAxisHelper:
    """
    Simplifies HallMachine operations that are selected by axis.

    For example, to set the speed of an arbitrary axis ``axis`` to ``s``, call
    ``HallMachineAxisHelper(hm).d_set_speed[axis](s)``, where ``hm`` is a HallMachine object.
    """

    def __init__(self, hm: HallMachine):
        xyz = (Axis.X, Axis.Y, Axis.Z)

        def attr_reader(owner_name, attr):
            # Re-read hm.<owner_name>.<attr> on every call so the value is always current.
            return lambda: getattr(getattr(hm, owner_name), attr)

        def relative_V_reader(index):
            return lambda: hm.get_V_relative()[index]

        self.d_set_speed = dict(zip(xyz, (hm.set_x_speed, hm.set_y_speed, hm.set_z_speed)))
        self.d_jog_to = dict(zip(xyz, (hm.x_jog_to, hm.y_jog_to, hm.z_jog_to)))
        self.d_stop = dict(zip(xyz, (hm.x_stop, hm.y_stop, hm.z_stop)))
        self.d_get_speed = {axis: attr_reader("umac_data_", "%s_speed" % name)
                            for axis, name in zip(xyz, "xyz")}
        self.d_get_B = {axis: attr_reader("B_", name) for axis, name in zip(xyz, "xyz")}
        self.d_get_V = {axis: attr_reader("V_", name) for axis, name in zip(xyz, "xyz")}
        self.d_get_V_relative = {axis: relative_V_reader(index) for index, axis in enumerate(xyz)}
        self.d_find_0 = dict(zip(xyz, (hm.x_find_zero, hm.y_find_zero, hm.z_find_zero)))
        stopped_flags = (
            (Axis.X, "x"), (Axis.Y, "y"), (Axis.Z, "z"),
            (Axis.Z2, "z2"), (Axis.A, "a"), (Axis.C, "c"),
        )
        self.d_axis_stoped = {axis: attr_reader("umac_data_", "%s_stoped" % name)
                              for axis, name in stopped_flags}
        self.hm = hm

    def get_pos(self, axis: Axis):
        """Current coordinate of ``axis``, read from a fresh snapshot of ``hm.umac_data_``."""
        snapshot = UMAC_Pos.from_UMACData(self.hm.umac_data_)
        return snapshot.d_get_pos_by_axis[axis]()

    def all_axis_stopped(self):
        """
        Whether every axis reports that it has stopped.

        Stops at the first axis that is not stopped and returns that (falsy) flag;
        otherwise returns the last flag read, or True if there are no axes.
        """
        res = True
        for read_flag in self.d_axis_stoped.values():
            res = read_flag()
            if not res:
                return res
        return res


if __name__ == "__main__":
    # Smoke test: create the machine and refresh all readings once.
    hm = HallMachine()
    hm.update_all()
    # A second instance is refused by the hm_cnt guard in HallMachineWraper.__init__,
    # so this line is expected to raise RuntimeError.
    hm2 = HallMachine()
